package easy;

import java.util.ArrayList;
import java.util.List;

public class Solution_1260 {

    /**
     * Shifts every element of a rectangular grid {@code k} positions forward in
     * row-major order, wrapping the last element around to the first position.
     *
     * <p>A positive {@code k} moves elements right/down. A negative {@code k}
     * moves them left/up. Any magnitude of {@code k} is accepted. It is reduced
     * modulo the cell count first, so huge or negative shifts cannot overflow
     * an index.
     *
     * <p>NOTE(review): the grid is assumed to be rectangular (every row as long
     * as {@code grid[0]}).
     *
     * @param grid the input grid; it is not modified
     * @param k    number of positions to shift
     * @return a new list of rows holding the shifted values. An empty grid
     *         yields an empty list. A grid whose rows have no columns yields
     *         one empty row per input row.
     */
    public List<List<Integer>> shiftGrid(int[][] grid, int k) {
        int rows = grid.length;
        List<List<Integer>> result = new ArrayList<>(rows);
        if (rows == 0) {
            return result;
        }
        int cols = grid[0].length;
        int count = rows * cols;
        if (count == 0) {
            // No cells to shift: keep the row structure with empty rows.
            for (int r = 0; r < rows; r++) {
                result.add(new ArrayList<>());
            }
            return result;
        }
        // Reduce first so later index arithmetic stays within [0, 2*count).
        int shift = Math.floorMod(k, count);
        for (int r = 0; r < rows; r++) {
            List<Integer> row = new ArrayList<>(cols);
            for (int c = 0; c < cols; c++) {
                // Destination p = r*cols + c receives the element that was
                // `shift` positions earlier in row-major order.
                int source = Math.floorMod(r * cols + c - shift, count);
                row.add(grid[source / cols][source % cols]);
            }
            result.add(row);
        }
        return result;
    }

    /** Small demo: prints the 3x3 grid 1..9 shifted by one position. */
    public static void main(String[] args) {
        Solution_1260 model = new Solution_1260();
        System.out.println(model.shiftGrid(new int[][]{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, 1));
    }
}
